import os 
import sys 
import json 
import numpy as np 
import random 
import h5py 
from tqdm import tqdm 
import webdataset as wds 
import torch 
import torch.nn  as nn 
import torch.nn.functional  as F 
from torch.utils.data  import DataLoader 
from torch.utils.tensorboard  import SummaryWriter 
from scipy.stats  import pearsonr 

sys.path.append("./mindeye2_src")  
from mindeye2_src.utils  import cosine_anneal, soft_clip_loss, batchwise_cosine_similarity, topk, seed_everything 
from mindeye2_src.models import PriorNetwork, BrainDiffusionPrior 
from models import  VoxelVAE,CentralFoveaAttention

 
seed_everything(42)

# =============== 2. Hyperparameters ===============
device = torch.device('cuda:0')
data_path = 'dataset'
subj = 1

# Iteration budget: an "epoch" is 750 samples per session times num_session sessions.
num_session = 40
num_samples_per_epoch = 750 * num_session
batch_size, test_batch_size = 32, 3000
num_iterations_per_epoch = num_samples_per_epoch // batch_size
num_epochs = 150

# VAE architecture; these values are used to rebuild the pretrained VoxelVAE below.
hidden_dim, n_blocks = 256, 2
kl_beta = 1e-4  # NOTE(review): not referenced anywhere in this script

# Prior training
loss_scale = 300.0  # multiplier applied to the prior loss before backward()
timesteps = 100     # passed to BrainDiffusionPrior(timesteps=...)
drop_prob = 0.2     # passed to BrainDiffusionPrior(cond_drop_prob=...)

# Paths
voxel_vae_path = f"./subj0{subj}_vae_eye/ckpt_300.pt"   # assumes the VAE was trained for 300 epochs
output_dir = f"./subj0{subj}_prior_vae_eye"
os.makedirs(output_dir, exist_ok=True)
 
# =============== 3. Datasets ===============
# 3.1 fMRI betas, read fully into host memory.
with h5py.File(f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5', 'r') as f:
    # f['betas'][:] already returns a fresh ndarray; torch.from_numpy wraps it
    # instead of copying it again (torch.tensor always copies), so loading no
    # longer needs two full-size copies of the betas matrix at once.
    # .float() is a no-op when the data is already float32.
    voxels = torch.from_numpy(f['betas'][:]).float()
num_voxels = voxels.shape[-1]

# 3.2 CLIP embeddings: kept as a lazily-read h5py dataset, so the file handle must
# stay open until it is closed at the end of the script.
clip_emb_file = h5py.File(f'{data_path}/clip_embeddings.hdf5', 'r')
clip_embeddings = clip_emb_file['embeddings']  # (N, 257, 768)
 
# 3.3 WebDataset 
def my_split_by_node(urls):
    """Node splitter that performs no splitting: hand back the shard URLs unchanged."""
    return urls
def _wds_pipeline(url, resampled):
    """Build the WebDataset pipeline shared by the train and test splits.

    Samples go through ``.shuffle(750, initial=1500)`` with an rng seeded by
    ``random.Random(42)``, are decoded with the "torch" decoder, and are emitted as
    the tuple (behav, past_behav, future_behav, olds_behav) taken from the
    matching ``*.npy`` entries.
    """
    return (
        wds.WebDataset(url, resampled=resampled, nodesplitter=my_split_by_node)
        .shuffle(750, initial=1500, rng=random.Random(42))
        .decode("torch")
        .rename(behav="behav.npy", past_behav="past_behav.npy",
                future_behav="future_behav.npy", olds_behav="olds_behav.npy")
        .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
    )


def _wds_loader(dataset, loader_batch_size):
    """Wrap a WebDataset pipeline in a DataLoader with this script's worker settings."""
    # shuffle stays False: WebDataset pipelines are iterable datasets, which the
    # DataLoader cannot shuffle; shuffling already happens inside the pipeline.
    return DataLoader(dataset, batch_size=loader_batch_size, shuffle=False,
                      drop_last=False, pin_memory=True, num_workers=4, prefetch_factor=2)


train_url = f"{data_path}/wds/subj0{subj}/train/{{0..{num_session-1}}}.tar"
# resampled=True samples shards with replacement without end; the training loop
# decides how many batches make up one epoch.
train_data = _wds_pipeline(train_url, resampled=True)
train_dl = _wds_loader(train_data, batch_size)

test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"
test_data = _wds_pipeline(test_url, resampled=False)
# NOTE(review): test_dl is constructed but never iterated in this script.
test_dl = _wds_loader(test_data, test_batch_size)
 
# =============== 4. Initialise models ===============
# Load the pretrained VoxelVAE; it is frozen and only used to encode voxels.
vae = VoxelVAE(
    num_voxels=num_voxels,
    token_dim=768,
    num_tokens=257,
    hidden_dim=hidden_dim,
    n_blocks=n_blocks,
    drop=0.15,
).to(device)
# map_location="cpu": deserialize on the host instead of on whatever GPU the
# checkpoint was saved from (which may not exist here, and would otherwise
# temporarily hold the whole checkpoint). load_state_dict then copies the
# weights into vae's parameters, which already live on `device`.
vae.load_state_dict(torch.load(voxel_vae_path, map_location="cpu")['model'])
vae.eval()  # evaluation mode
vae.requires_grad_(False)  # freeze every parameter

 
# Prior network configuration; token count and width follow the CLIP embedding shape.
clip_seq_dim = 257
clip_emb_dim = 768
out_dim = clip_emb_dim
depth = 6
dim_head = 48
# Derive the head count from dim_head so that heads * dim_head == clip_emb_dim
# (768 // 48 = 16). A hard-coded divisor of 52 (a different dim_head) gave only
# 14 heads here (14 * 48 = 672).
# NOTE: prior checkpoints saved with the 14-head layout will not load into this model.
heads = clip_emb_dim // dim_head
if heads * dim_head != clip_emb_dim:
    raise ValueError(
        f"clip_emb_dim ({clip_emb_dim}) must be divisible by dim_head ({dim_head})"
    )

prior = PriorNetwork(
    dim=out_dim,
    depth=depth,
    dim_head=dim_head,
    heads=heads,
    causal=False,
    num_tokens=clip_seq_dim,
    learned_query_mode="pos_emb",
)
diffusion_prior = BrainDiffusionPrior(
    net=prior,
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=drop_prob,
    image_embed_scale=None,
).to(device)

# Load the CentralFoveaAttention module; its weights are stored under the 'attn'
# key of the VAE checkpoint. It is frozen and only used for inference.
attn = CentralFoveaAttention(embed_dim=768, grid_size=16).to(device)
# map_location="cpu" avoids materialising the whole checkpoint a second time on the
# GPU it was saved from; load_state_dict copies the weights onto attn's device.
attn.load_state_dict(torch.load(voxel_vae_path, map_location="cpu")['attn'])
attn.eval()  # must be in eval mode
attn.requires_grad_(False)  # freeze every parameter
 
# Optimizer and learning-rate scheduler
optimizer = torch.optim.AdamW(diffusion_prior.parameters(), lr=3e-4)
# The scheduler is stepped once per training iteration over the whole run.
total_steps = num_epochs * num_iterations_per_epoch
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    total_steps=total_steps,
    pct_start=2 / num_epochs,  # LR increases during the first 2/num_epochs of all steps
    final_div_factor=1000,
    last_epoch=-1,
)

# ---- Resume training ----
def find_latest_checkpoint(out_dir):
    """Return the newest ``ckpt_<epoch>.pt`` file in `out_dir` and its epoch.

    Only files named exactly ``ckpt_<digits>.pt`` are considered. Other names
    (e.g. ``ckpt_best.pt``) are ignored instead of crashing the epoch parsing.

    Args:
        out_dir: directory to scan.

    Returns:
        ``(path, epoch)`` for the highest epoch found, or ``(None, 0)`` when the
        directory does not exist or contains no numbered checkpoint.
    """
    prefix, suffix = "ckpt_", ".pt"
    try:
        names = os.listdir(out_dir)
    except FileNotFoundError:
        return None, 0

    best_name, best_epoch = None, -1
    for name in names:
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        stem = name[len(prefix):len(name) - len(suffix)]
        # isdecimal() accepts exactly the digit strings that int() can parse.
        if not stem.isdecimal():
            continue
        epoch = int(stem)
        if epoch > best_epoch:
            best_name, best_epoch = name, epoch

    if best_name is None:
        return None, 0
    return os.path.join(out_dir, best_name), best_epoch

ckpt_path, start_epoch = find_latest_checkpoint(output_dir)
# Resume only when a checkpoint path was found and it is an actual file.
resume = bool(ckpt_path) and os.path.isfile(ckpt_path)
if not resume:
    start_epoch = 0
    print("🆕  No checkpoint found, start training from scratch.")
else:
    print(f"🔄  Resuming training from {ckpt_path} (epoch {start_epoch})")
    ckpt = torch.load(ckpt_path, map_location=device)
    # Restore model, optimizer and LR-scheduler state, in that order.
    for component, key in ((diffusion_prior, 'model'),
                           (optimizer, 'optimizer'),
                           (lr_scheduler, 'scheduler')):
        component.load_state_dict(ckpt[key])
    # The saved epoch counter is authoritative over the one parsed from the filename.
    start_epoch = ckpt['epoch']
 
# =============== 5. TensorBoard ===============
# Event files are written to <output_dir>/tensorboard, created up front.
tb_dir = os.path.join(output_dir, "tensorboard")
os.makedirs(tb_dir, exist_ok=True)
writer = SummaryWriter(log_dir=tb_dir)
 
# =============== 6. Train the prior ===============
def _collect_epoch_batches(loader, n_batches):
    """Gather the sample indices for up to `n_batches` training batches from `loader`.

    For each batch, ``behav[:, 0, 0]`` is used as the index into ``clip_embeddings``
    and ``behav[:, 0, 5]`` as the row index into ``voxels``. Batches that contain the
    same CLIP index twice are skipped.

    Only integer indices are kept. The original approach pre-allocated dense
    per-epoch buffers, which for the CLIP tokens alone needs
    n_batches * batch_size * 257 * 768 float32 values (~24 GB for 937 batches of 32).
    Embeddings and voxel rows are now gathered per training step instead.

    Returns:
        A list of ``(clip_idx, voxel_idx)`` int64 numpy array pairs, ``n_batches``
        long unless `loader` is exhausted first.
    """
    batches = []
    for behav0, _, _, _ in loader:
        embedding_idx = behav0[:, 0, 0].cpu().long().numpy()
        # np.unique returns the indices sorted (h5py fancy indexing requires increasing,
        # non-repeated indices) plus the position of each one in the original batch.
        clip_idx, first_pos = np.unique(embedding_idx, return_index=True)
        if len(clip_idx) != len(embedding_idx):
            continue  # the batch repeats an image; skip it
        # Reorder the voxel indices so they stay aligned with the sorted CLIP indices.
        voxel_idx = behav0[:, 0, 5].cpu().long().numpy()[first_pos]
        batches.append((clip_idx, voxel_idx))
        if len(batches) >= n_batches:
            break
    return batches


diffusion_prior.train()
for epoch in range(start_epoch, num_epochs):
    losses, mse_losses, lrs = [], [], []
    epoch_batches = _collect_epoch_batches(train_dl, num_iterations_per_epoch)

    pbar = tqdm(epoch_batches)
    for step, (clip_idx, voxel_idx) in enumerate(pbar):
        optimizer.zero_grad()
        voxel = voxels[torch.from_numpy(voxel_idx)].to(device)

        # CLIP embeddings for this batch, cast to float32.
        image_rep = torch.from_numpy(clip_embeddings[clip_idx]).float().to(device)

        # The VAE and the attention module are frozen: run them without building a graph.
        with torch.no_grad():
            _, _, mu, _ = vae(voxel)
            queried_image_rep = attn(image_rep)

        # Prior step: attention-queried CLIP tokens go in as text_embed, VAE mean as image_embed.
        loss, _ = diffusion_prior(text_embed=queried_image_rep, image_embed=mu)
        if loss.isnan().any():
            raise ValueError('NaN loss')

        mse = loss.item()          # unscaled loss, reported as "mse"
        loss = loss * loss_scale   # scale before backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_norm=2.0)
        optimizer.step()
        lr_scheduler.step()

        scaled_loss = loss.item()
        losses.append(scaled_loss)
        mse_losses.append(mse)
        lrs.append(optimizer.param_groups[0]['lr'])
        pbar.set_postfix(loss=scaled_loss, mse=mse, lr=lrs[-1])

        # TensorBoard: running means over the current epoch, logged every step.
        global_step = epoch * num_iterations_per_epoch + step
        writer.add_scalar("train/loss", np.mean(losses), global_step)
        writer.add_scalar("train/mse", np.mean(mse_losses), global_step)
        writer.add_scalar("train/lr", lrs[-1], global_step)

    print(f"Epoch {epoch+1:03d}, train loss: {np.mean(losses):.4f},  lr: {np.mean(lrs):.6f},  mse: {np.mean(mse_losses):.4f}")
    torch.save({
        'model': diffusion_prior.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': lr_scheduler.state_dict(),
        'epoch': epoch + 1,
    }, f"{output_dir}/ckpt_{epoch+1:03d}.pt")
 
# =============== 7. Cleanup ===============
# Close the lazily-read CLIP embedding file, then the TensorBoard writer.
for handle in (clip_emb_file, writer):
    handle.close()